from mynumpy import *

def dist_params(w):
    """Split a two-entry parameter vector into ``(loc, scale)``.

    The scale is stored in log space: ``w[1]`` is exponentiated, so any real
    value maps to a positive scale (up to floating-point under/overflow).

    Args:
        w: sequence/array with exactly two entries, ``[loc, log_scale]``.

    Returns:
        tuple ``(loc, scale)`` where ``loc = w[0]`` and ``scale = exp(w[1])``.

    Raises:
        ValueError: if ``w`` does not have exactly two entries.
    """
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would let a malformed ``w`` through silently.
    if len(w) != 2:
        raise ValueError(
            "expected 2 parameters [loc, log_scale], got {}".format(len(w)))
    loc = w[0]
    scale = exp(w[1])
    return loc, scale

class estimator_naive:
    """Naive estimator that draws z by an inverse-CDF transform of omega.

    A uniform variate ``omega`` is pushed through the base distribution's
    quantile function (``ppf``) at the parameters ``dist_params(w)``.
    ``logR`` is the log ratio of the target density to that base density,
    evaluated at the resulting ``z``.

    Args:
        target: object exposing ``logp(z)``, the log target density.
        basedist: distribution object exposing ``ppf``, ``cdf``, ``pdf`` and
            ``logpdf`` called as ``(x, loc, scale)``. Presumably a
            ``scipy.stats`` continuous distribution; TODO confirm.
    """

    def __init__(self, target, basedist):
        self.target = target
        self.basedist = basedist
        # Identifier string for this estimator.
        self.label = 'base'
        # Number of omega entries per sample.
        self.omega_dim = 1
        # Length of the parameter vector w = [loc, log_scale] (see dist_params).
        self.w_dim = 2

    def sample_omega(self, num):
        """Draw ``num`` uniform variates on [0, 1) via ``np.random.rand``."""
        return np.random.rand(num)

    def sample_z(self, omega, w):
        """Map ``omega`` to z through the base quantile function at ``w``."""
        loc, scale = dist_params(w)
        return self.basedist.ppf(omega, loc, scale)

    def sample_zs(self, omega, w):
        """Return ``self.sample_z(omega, w)``; this estimator has no variant."""
        return self.sample_z(omega, w)

    def z_to_omega(self, z, w):
        """Map z back to omega through the base CDF at ``w``.

        For a continuous base distribution this inverts ``sample_z``.
        """
        loc, scale = dist_params(w)
        return self.basedist.cdf(z, loc, scale)

    def logR(self, omega, w):
        """Return ``target.logp(z) - basedist.logpdf(z)`` at z = sample_z(omega, w)."""
        z = self.sample_z(omega, w)
        log_target = self.target.logp(z)
        loc, scale = dist_params(w)
        log_base = self.basedist.logpdf(z, loc, scale)
        return log_target - log_base

    def a(self, z, w):
        """Return the base density ``pdf(z)`` at the parameters ``dist_params(w)``."""
        loc, scale = dist_params(w)
        return self.basedist.pdf(z, loc, scale)